{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Batch Normalization in `gluon`\n",
    "\n",
    "In the preceding section, [we implemented batch normalization ourselves](../chapter04_convolutional-neural-networks/cnn-batch-norm-scratch.ipynb) using NDArray and autograd.\n",
    "As with most commonly used neural network layers,\n",
    "Gluon has batch normalization predefined,\n",
    "so this section is going to be straightforward."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-18T03:21:49.174951Z",
     "start_time": "2017-10-18T03:21:48.205450Z"
    }
   },
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "import mxnet as mx\n",
    "from mxnet import nd, autograd\n",
    "from mxnet import gluon\n",
    "import numpy as np\n",
    "# Fix the seed of MXNet's random number generator.\n",
    "mx.random.seed(1)\n",
    "# Context (CPU) that the later cells pass to parameter initialization and `as_in_context`.\n",
    "ctx = mx.cpu()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The MNIST dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-18T03:21:50.220488Z",
     "start_time": "2017-10-18T03:21:49.176860Z"
    }
   },
   "outputs": [],
   "source": [
    "batch_size = 64\n",
    "num_inputs = 784\n",
    "num_outputs = 10\n",
    "\n",
    "\n",
    "def transform(data, label):\n",
    "    \"\"\"Convert one raw sample into a float32 (image, label) pair.\n",
    "\n",
    "    The image is cast to float32, its axes are permuted with (2, 0, 1),\n",
    "    and its values are divided by 255. The label is cast to float32.\n",
    "    \"\"\"\n",
    "    image = nd.transpose(data.astype(np.float32), (2, 0, 1)) / 255\n",
    "    return image, label.astype(np.float32)\n",
    "\n",
    "\n",
    "# Build each dataset, then wrap it in a DataLoader; only the training loader shuffles.\n",
    "mnist_train = mx.gluon.data.vision.MNIST(train=True, transform=transform)\n",
    "train_data = mx.gluon.data.DataLoader(mnist_train, batch_size, shuffle=True)\n",
    "mnist_test = mx.gluon.data.vision.MNIST(train=False, transform=transform)\n",
    "test_data = mx.gluon.data.DataLoader(mnist_test, batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define a CNN with Batch Normalization\n",
    "\n",
    "To add batch normalization to a ``gluon`` model defined with Sequential,\n",
    "we only need to add a few lines. \n",
    "Specifically, we just insert `BatchNorm` layers before applying the ReLU activations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-18T03:21:50.292271Z",
     "start_time": "2017-10-18T03:21:50.222527Z"
    }
   },
   "outputs": [],
   "source": [
    "num_fc = 512\n",
    "net = gluon.nn.Sequential()\n",
    "with net.name_scope():\n",
    "    # The layers are created inside the name scope, in the order they are applied.\n",
    "    layers = [\n",
    "        # First convolutional stage: conv -> batch norm -> ReLU -> max pool.\n",
    "        gluon.nn.Conv2D(channels=20, kernel_size=5),\n",
    "        gluon.nn.BatchNorm(axis=1, center=True, scale=True),\n",
    "        gluon.nn.Activation(activation='relu'),\n",
    "        gluon.nn.MaxPool2D(pool_size=2, strides=2),\n",
    "        # Second convolutional stage, same pattern with 50 channels.\n",
    "        gluon.nn.Conv2D(channels=50, kernel_size=5),\n",
    "        gluon.nn.BatchNorm(axis=1, center=True, scale=True),\n",
    "        gluon.nn.Activation(activation='relu'),\n",
    "        gluon.nn.MaxPool2D(pool_size=2, strides=2),\n",
    "        # The Flatten layer collapses all axes, except the first one, into one axis.\n",
    "        gluon.nn.Flatten(),\n",
    "        # Fully connected stage: dense -> batch norm -> ReLU.\n",
    "        gluon.nn.Dense(num_fc),\n",
    "        gluon.nn.BatchNorm(axis=1, center=True, scale=True),\n",
    "        gluon.nn.Activation(activation='relu'),\n",
    "        # Output layer with one unit per class.\n",
    "        gluon.nn.Dense(num_outputs),\n",
    "    ]\n",
    "    for layer in layers:\n",
    "        net.add(layer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parameter initialization\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-18T03:21:50.311368Z",
     "start_time": "2017-10-18T03:21:50.296296Z"
    }
   },
   "outputs": [],
   "source": [
    "# Initialize every parameter collected from the network on `ctx`, using the Xavier initializer.\n",
    "net.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Softmax cross-entropy Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-18T03:21:50.335025Z",
     "start_time": "2017-10-18T03:21:50.322603Z"
    }
   },
   "outputs": [],
   "source": [
    "# Loss object; the training loop calls it as softmax_cross_entropy(output, label).\n",
    "softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-10-18T03:21:50.350590Z",
     "start_time": "2017-10-18T03:21:50.339939Z"
    }
   },
   "outputs": [],
   "source": [
    "# Trainer over the network's parameters, configured with the 'sgd' optimizer and learning rate 0.1.\n",
    "trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': .1})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Write evaluation loop to calculate accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2017-10-18T03:21:48.241Z"
    }
   },
   "outputs": [],
   "source": [
    "def evaluate_accuracy(data_iterator, net):\n",
    "    \"\"\"Return the classification accuracy of `net` over `data_iterator`.\n",
    "\n",
    "    Each (data, label) batch is moved to `ctx`, the predicted class is the\n",
    "    argmax of the network output along axis 1, and the comparison with the\n",
    "    labels is accumulated in an `mx.metric.Accuracy` metric.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    data_iterator : iterable yielding (data, label) batches\n",
    "    net : callable mapping a data batch to per-class outputs\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The value part (index 1) of `acc.get()`.\n",
    "    \"\"\"\n",
    "    acc = mx.metric.Accuracy()\n",
    "    for data, label in data_iterator:\n",
    "        data = data.as_in_context(ctx)\n",
    "        label = label.as_in_context(ctx)\n",
    "        predictions = nd.argmax(net(data), axis=1)\n",
    "        # The metric API is documented as taking lists of NDArrays, so wrap each batch.\n",
    "        acc.update(preds=[predictions], labels=[label])\n",
    "    return acc.get()[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2017-10-18T03:21:48.244Z"
    }
   },
   "outputs": [],
   "source": [
    "epochs = 1\n",
    "smoothing_constant = .01\n",
    "\n",
    "for e in range(epochs):\n",
    "    for i, (data, label) in enumerate(train_data):\n",
    "        # Move the batch onto the compute context.\n",
    "        data = data.as_in_context(ctx)\n",
    "        label = label.as_in_context(ctx)\n",
    "\n",
    "        # Record the forward pass, then backpropagate the per-sample losses.\n",
    "        with autograd.record():\n",
    "            output = net(data)\n",
    "            loss = softmax_cross_entropy(output, label)\n",
    "        loss.backward()\n",
    "        # step() receives the number of samples in this batch.\n",
    "        trainer.step(data.shape[0])\n",
    "\n",
    "        # Exponential moving average of the mean batch loss,\n",
    "        # seeded with the loss of the very first batch.\n",
    "        curr_loss = nd.mean(loss).asscalar()\n",
    "        if (i == 0) and (e == 0):\n",
    "            moving_loss = curr_loss\n",
    "        else:\n",
    "            moving_loss = (1 - smoothing_constant) * moving_loss + (smoothing_constant) * curr_loss\n",
    "\n",
    "    # Report accuracy on both splits at the end of every epoch.\n",
    "    test_accuracy = evaluate_accuracy(test_data, net)\n",
    "    train_accuracy = evaluate_accuracy(train_data, net)\n",
    "    print(\"Epoch %s. Loss: %s, Train_acc %s, Test_acc %s\" % (e, moving_loss, train_accuracy, test_accuracy))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Next\n",
    "[Introduction to recurrent neural networks](../chapter05_recurrent-neural-networks/simple-rnn.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For whinges or inquiries, [open an issue on GitHub](https://github.com/zackchase/mxnet-the-straight-dope)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
